{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# FINN - How to work with ONNX\n",
    "\n",
    "This notebook gives an overview of the ONNX protobuf format, shows how to create and manipulate an ONNX model, and uses FINN helper functions to work with it."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Outline\n",
    "* #### How to create a simple ONNX model\n",
    "* #### How to manipulate an ONNX model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### How to create a simple ONNX model\n",
    "\n",
    "To explain how to create an ONNX model, a simple example built from mathematical operations is used. All nodes come from the [standard operations library of ONNX](https://github.com/onnx/onnx/blob/main/docs/Operators.md).\n",
    "\n",
    "First `onnx` is imported; then the helper function `onnx.helper.make_node` is used to create a node."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import onnx\n",
    "from qonnx.util.basic import qonnx_make_model\n",
    "\n",
    "Add1_node = onnx.helper.make_node(\n",
    "    'Add',\n",
    "    inputs=['in1', 'in2'],\n",
    "    outputs=['sum1'],\n",
    "    name='Add1'\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The first argument of `make_node` is the operation type. In this case it is `'Add'`, so the node is an adder. Then the names of the input and output tensors are passed, and finally the node itself is given a name.\n",
    "\n",
    "For this example we need two more adder nodes and one abs node. The output shall be rounded, so one round node is needed as well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Add2_node = onnx.helper.make_node(\n",
    "    'Add',\n",
    "    inputs=['sum1', 'in3'],\n",
    "    outputs=['sum2'],\n",
    "    name='Add2',\n",
    ")\n",
    "\n",
    "Add3_node = onnx.helper.make_node(\n",
    "    'Add',\n",
    "    inputs=['abs1', 'abs1'],\n",
    "    outputs=['sum3'],\n",
    "    name='Add3',\n",
    ")\n",
    "\n",
    "Abs_node = onnx.helper.make_node(\n",
    "    'Abs',\n",
    "    inputs=['sum2'],\n",
    "    outputs=['abs1'],\n",
    "    name='Abs'\n",
    ")\n",
    "\n",
    "Round_node = onnx.helper.make_node(\n",
    "    'Round',\n",
    "    inputs=['sum3'],\n",
    "    outputs=['out1'],\n",
    "    name='Round',\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The input and output names of the nodes already give an idea of the structure of the resulting graph. To integrate the nodes into a graph, the inputs and outputs of the graph have to be specified first. In ONNX all data edges are tensors, so the helper function `onnx.helper.make_tensor_value_info` is used to create value infos for the input and output tensors of the graph. The ONNX float type (`onnx.TensorProto.FLOAT`) is used as data type."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "in1 = onnx.helper.make_tensor_value_info(\"in1\", onnx.TensorProto.FLOAT, [4, 4])\n",
    "in2 = onnx.helper.make_tensor_value_info(\"in2\", onnx.TensorProto.FLOAT, [4, 4])\n",
    "in3 = onnx.helper.make_tensor_value_info(\"in3\", onnx.TensorProto.FLOAT, [4, 4])\n",
    "out1 = onnx.helper.make_tensor_value_info(\"out1\", onnx.TensorProto.FLOAT, [4, 4])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now the graph can be built. First all nodes are passed. Note that their order matters: the nodes must be listed in topological order, so that every node appears after the nodes whose outputs it consumes. This means Add2 must not be listed before Add1, because Add2 depends on the result of Add1. A name is then assigned to the graph. This is followed by the inputs and outputs.\n",
    "\n",
    "`value_info` of the graph contains the remaining tensors within the graph. When creating the nodes we already defined names for the inner data edges; now each of them is described as a tensor with data type float and a certain shape."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph = onnx.helper.make_graph(\n",
    "    nodes=[\n",
    "        Add1_node,\n",
    "        Add2_node,\n",
    "        Abs_node,\n",
    "        Add3_node,\n",
    "        Round_node,\n",
    "    ],\n",
    "    name=\"simple_graph\",\n",
    "    inputs=[in1, in2, in3],\n",
    "    outputs=[out1],\n",
    "    value_info=[\n",
    "        onnx.helper.make_tensor_value_info(\"sum1\", onnx.TensorProto.FLOAT, [4, 4]),\n",
    "        onnx.helper.make_tensor_value_info(\"sum2\", onnx.TensorProto.FLOAT, [4, 4]),\n",
    "        onnx.helper.make_tensor_value_info(\"abs1\", onnx.TensorProto.FLOAT, [4, 4]),\n",
    "        onnx.helper.make_tensor_value_info(\"sum3\", onnx.TensorProto.FLOAT, [4, 4]),\n",
    "    ],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Important**: In our example, the shape of the tensors does not change during the calculation. This is not always the case. So you have to make sure that you specify the shape correctly.\n",
    "\n",
    "Now a model can be created from the graph and saved using the `onnx.save` function. The model is saved in .onnx format and can be reloaded with `onnx.load()`. This also means that you can easily share your own model in .onnx format with others."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "onnx_model = qonnx_make_model(graph, producer_name=\"simple-model\")\n",
    "onnx.save(onnx_model, '/tmp/simple_model.onnx')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To visualize the created model, [netron](https://github.com/lutzroeder/netron) can be used. Netron is a visualizer for neural network, deep learning and machine learning models. FINN provides a utility function for visualization with netron, which we import and use in the following."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from finn.util.visualization import showInNetron"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "showInNetron('/tmp/simple_model.onnx')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Netron also allows you to interactively explore the model. If you click on a node, the node attributes will be displayed. \n",
    "\n",
    "In order to test the resulting model, a function is first written in Python that calculates the expected output. Because numpy arrays are to be used, numpy is imported first."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def expected_output(in1, in2, in3):\n",
    "    sum1 = np.add(in1, in2)\n",
    "    sum2 = np.add(sum1, in3)\n",
    "    abs1 = np.absolute(sum2)\n",
    "    sum3 = np.add(abs1, abs1)\n",
    "    return np.round(sum3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then the values for the three inputs are generated. Random numbers drawn uniformly from the interval [-5, 5) are used as `float32` arrays of shape (4, 4)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "in1_values = np.asarray(np.random.uniform(low=-5, high=5, size=(4,4)), dtype=np.float32)\n",
    "in2_values = np.asarray(np.random.uniform(low=-5, high=5, size=(4,4)), dtype=np.float32)\n",
    "in3_values = np.asarray(np.random.uniform(low=-5, high=5, size=(4,4)), dtype=np.float32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can pass these values directly to the function we just wrote to calculate the expected result. For the created model, the inputs must be collected in a dictionary that maps each graph input name to its array; this dictionary is then passed to the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_dict = {}\n",
    "input_dict[\"in1\"] = in1_values\n",
    "input_dict[\"in2\"] = in2_values\n",
    "input_dict[\"in3\"] = in3_values"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To run the model and calculate the output, [onnxruntime](https://github.com/microsoft/onnxruntime) can be used. ONNX Runtime is a performance-focused inference engine for ONNX models from Microsoft. An `InferenceSession` object is created for the model, and its `.run` method executes the model. Passing `None` as the first argument of `.run` requests all outputs of the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import onnxruntime as rt\n",
    "\n",
    "sess = rt.InferenceSession(onnx_model.SerializeToString())\n",
    "output = sess.run(None, input_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The input values are also passed to the reference function. Now the output of the model execution can be compared with that of the reference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ref_output = expected_output(in1_values, in2_values, in3_values)\n",
    "print(\"The output of the ONNX model is: \\n{}\".format(output[0]))\n",
    "print(\"\\nThe output of the reference function is: \\n{}\".format(ref_output))\n",
    "\n",
    "if (output[0] == ref_output).all():\n",
    "    print(\"\\nThe results are the same!\")\n",
    "else:\n",
    "    raise Exception(\"Something went wrong, the output of the model doesn't match the expected output!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have verified that the model works as we expected it to, we can continue working with the graph."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### How to manipulate an ONNX model\n",
    "\n",
    "In the model there are two successive adder nodes. An adder node in ONNX can only add two inputs, but there is also the [**Sum**](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sum) node, which can process more than two inputs. So a reasonable change to the graph is to combine the two successive adder nodes into one sum node."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the following we assume that we do not know the structure of the model, so we first check whether there are two consecutive adders in the graph and then convert them into a sum node.\n",
    "\n",
    "Here we make use of FINN. FINN provides a thin wrapper around the model which offers several additional helper functions to manipulate the graph. This so-called `ModelWrapper` lives in the QONNX repository, which contains a lot of functionality that is used by FINN. You can find it [here](https://github.com/fastmachinelearning/qonnx/blob/main/src/qonnx/core/modelwrapper.py)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from qonnx.core.modelwrapper import ModelWrapper\n",
    "finn_model = ModelWrapper(onnx_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As explained in the previous section, it is important that the nodes are listed in the correct order. If a new node has to be inserted or an old node has to be replaced, it is important to do that in the appropriate position. The following function serves this purpose. It returns a dictionary, which contains the node name as key and the respective node index as value."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_node_id(model):\n",
    "    # Map every node name to its position in the graph's node list.\n",
    "    node_index = {}\n",
    "    for node_ind, node in enumerate(model.graph.node):\n",
    "        node_index[node.name] = node_ind\n",
    "    return node_index"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The function scans the list of nodes and stores a run index (`node_ind`) as node index in the dictionary for every node name.\n",
    "\n",
    "Another helper function is being implemented that searches for adder nodes in the graph and returns the found nodes. This is needed to determine if and which adder nodes are in the given model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def identify_adder_nodes(model):\n",
    "    add_nodes = []\n",
    "    for node in model.graph.node:\n",
    "        if node.op_type == \"Add\":\n",
    "            add_nodes.append(node)\n",
    "    return add_nodes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The function iterates over all nodes of the model and if the operation type is `\"Add\"` the node will be stored in `add_nodes`. At the end `add_nodes` is returned.\n",
    "\n",
    "If we apply this to our model, three nodes should be returned."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "add_nodes = identify_adder_nodes(finn_model)\n",
    "for node in add_nodes:\n",
    "    print(\"Found adder node: {}\".format(node.name))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Among other helper functions, `ModelWrapper` offers two functions that can help to determine the preceding and succeeding node of a node: `find_direct_successors` and `find_direct_predecessors`. So we can use one of them to define a function to find adder pairs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def adder_pair(model, node):\n",
    "    adder_pairs = []\n",
    "    node_pair = []\n",
    "    # find_direct_successors may return None for a node without successors,\n",
    "    # so fall back to an empty list in that case.\n",
    "    successor_list = model.find_direct_successors(node) or []\n",
    "\n",
    "    for successor in successor_list:\n",
    "        if successor.op_type == \"Add\":\n",
    "            node_pair.append(node)\n",
    "            node_pair.append(successor)\n",
    "            adder_pairs.append(node_pair)\n",
    "            node_pair = []\n",
    "    return adder_pairs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The function takes the model and a node as input. Two empty lists are created: `adder_pairs` collects the pairs that are returned as the result, and `node_pair` temporarily holds a single pair. Then `find_direct_successors` is used to get all successors of the node. If one of the successors is an adder node, the node is saved in `node_pair` together with that successor, and the pair is appended to `adder_pairs`. Then the temporary list is reset so it can be filled with the next adder node pair. Since it is theoretically possible for an adder node to have more than one subsequent adder node, a list of lists is returned, with one entry for each adder successor of the node.\n",
    "\n",
    "So now we can find out which adder node has an adder node as successor. Since the model is known, one adder pair (Add1+Add2) should be found when applying the function to the previously determined adder node list (`add_nodes`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for node in add_nodes:\n",
    "    add_pairs = adder_pair(finn_model, node)\n",
    "    for pair in add_pairs:\n",
    "        # Remember the pair; in this model exactly one pair is expected.\n",
    "        substitute_pair = pair\n",
    "        print(\"Found following pair that could be replaced by a sum node:\")\n",
    "        for pair_node in pair:\n",
    "            print(pair_node.name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that the pair to be replaced has been identified (`substitute_pair`), a sum node can be instantiated and inserted into the graph at the correct position. \n",
    "\n",
    "First of all, the inputs must be determined. The inputs of both adder nodes are used, except for the one input that is the output of the other adder node, because this intermediate tensor disappears when the pair is merged."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The pair consists of the first adder and its adder successor.\n",
    "first_add, second_add = substitute_pair\n",
    "\n",
    "input_list = []\n",
    "# Keep every input of either adder, except the tensor that connects the two.\n",
    "for inp in first_add.input:\n",
    "    if inp != second_add.output[0]:\n",
    "        input_list.append(inp)\n",
    "for inp in second_add.input:\n",
    "    if inp != first_add.output[0]:\n",
    "        input_list.append(inp)\n",
    "print(\"The new node gets the following inputs: \\n{}\".format(input_list))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The output of the sum node matches the output of the second adder node and can therefore be taken over directly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sum_output = substitute_pair[1].output[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sum node can be created with this information."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Sum_node = onnx.helper.make_node(\n",
    "    'Sum',\n",
    "    inputs=input_list,\n",
    "    outputs=[sum_output],\n",
    "    name=\"Sum\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The node can now be inserted into the graph and the old nodes are removed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "node_ids = get_node_id(finn_model)\n",
    "node_ind = node_ids[substitute_pair[0].name]\n",
    "graph.node.insert(node_ind, Sum_node)\n",
    "\n",
    "for node in substitute_pair:\n",
    "    graph.node.remove(node)"
   ]
  },
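  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To insert the node in the right place, the index of the first node of `substitute_pair` is used as node index for the sum node, which is embedded into the graph using `.insert`. Then the two elements of `substitute_pair` are deleted using `.remove`. `.insert` and `.remove` are list-like methods of the protobuf repeated field that holds the nodes of the ONNX graph. Note that the edit is applied to `graph`, the `GraphProto` object created in the first section, which is then used to build the new model."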
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The new graph is saved as ONNX model and can be visualized with Netron."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "onnx_model1 = qonnx_make_model(graph, producer_name=\"simple-model1\")\n",
    "onnx.save(onnx_model1, '/tmp/simple_model1.onnx')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "showInNetron('/tmp/simple_model1.onnx')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The visualization already shows that the insertion was successful, but it still has to be checked whether the result remains the same. To do so, the new model is executed with the same input values and its output is compared with the result of the reference function written in the previous section. onnxruntime can be used again, and the execution is analogous to that of the first model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sess = rt.InferenceSession(onnx_model1.SerializeToString())\n",
    "output = sess.run(None, input_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"The output of the manipulated ONNX model is: \\n{}\".format(output[0]))\n",
    "print(\"\\nThe output of the reference function is: \\n{}\".format(ref_output))\n",
    "\n",
    "if (output[0] == ref_output).all():\n",
    "    print(\"\\nThe results are the same!\")\n",
    "else:\n",
    "    raise Exception(\"Something went wrong, the output of the model doesn't match the expected output!\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
